import torch
from torch import nn
import layer_norm as ln
# Residual connection: mitigates exploding/vanishing gradients in deep stacks
class ResidualConnection(nn.Module):
    """Pre-norm residual connection: ``x + dropout(sublayer(norm(x)))``.

    Wraps an arbitrary sublayer (attention, feed-forward, ...) so that its
    output is added back onto the input, which eases gradient flow through
    deep stacks of layers.
    """

    def __init__(self, dim_model, dropout_p):
        """
        Args:
            dim_model: feature dimension normalized by the layer norm.
            dropout_p: dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        self.norm = ln.LayerNorm(dim_model)

    def forward(self, x, sublayer):
        """Normalize ``x``, run it through ``sublayer``, then add the residual."""
        normalized = self.norm(x)
        transformed = sublayer(normalized)
        return x + self.dropout(transformed)